{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bce1951a-8eef-4842-a70f-987b85a3240f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# installation\n",
    "!python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu\n",
    "!python3 -m pip install tokenizers -U\n",
    "!python3 -m pip install transformers -U"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9769e847-d838-473d-8d32-1061b3e0f1c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# go to maxtext/MaxText for library import\n",
    "\n",
    "current_dir = %pwd\n",
    "working_dir = current_dir.replace(\"scratch_code\", \"\")\n",
    "%cd $working_dir"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1c108fc-d739-471d-9c64-c08151845f06",
   "metadata": {},
   "source": [
    "# one layer mixtral model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf8eee59-295e-41f4-8c09-d2177b410ddc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path\n",
    "import pyconfig\n",
    "from transformers.models.mixtral.configuration_mixtral import MixtralConfig\n",
    "from MaxText.globals import MAXTEXT_PKG_DIR\n",
    "\n",
    "config_maxtext = pyconfig.initialize(\n",
    "    [None, os.path.join(MAXTEXT_PKG_DIR, \"configs\", \"base.yml\")],\n",
    "    base_emb_dim=4096,\n",
    "    base_num_query_heads=32,\n",
    "    base_num_kv_heads=8,\n",
    "    base_mlp_dim=14336,\n",
    "    base_num_decoder_layers=1,  # 1 layer for simplicity\n",
    "    head_dim=128,\n",
    "    mlp_activations=[\"silu\", \"linear\"],\n",
    "    vocab_size=32000,\n",
    "    enable_dropout=False,\n",
    "    logits_via_embedding=False,\n",
    "    normalization_layer_epsilon=1.0e-5,\n",
    "    num_experts=8,\n",
    "    num_experts_per_tok=2,\n",
    "    rope_max_timescale=1_000_000,\n",
    "    decoder_block=\"mistral\",\n",
    "    run_name=\"moe_test\",\n",
    "    enable_checkpointing=False,\n",
    "    dtype=\"bfloat16\",\n",
    "    weight_dtype=\"bfloat16\",\n",
    "    megablox=True,  # or False\n",
    "    max_target_length=4,\n",
    "    max_prefill_predict_length=3,\n",
    "    per_device_batch_size=1,\n",
    "    capacity_factor=-1,\n",
    "    scan_layers=False,\n",
    ")\n",
    "\n",
    "config_hf = MixtralConfig(\n",
    "    vocab_size=config_maxtext.vocab_size,\n",
    "    hidden_size=config_maxtext.emb_dim,\n",
    "    intermediate_size=config_maxtext.mlp_dim,\n",
    "    num_hidden_layers=config_maxtext.num_decoder_layers,\n",
    "    num_attention_heads=config_maxtext.base_num_query_heads,\n",
    "    num_key_value_heads=config_maxtext.num_kv_heads,\n",
    "    rms_norm_eps=config_maxtext.normalization_layer_epsilon,\n",
    "    rope_theta=config_maxtext.rope_max_timescale,\n",
    "    attention_dropout=0.0,\n",
    "    num_experts_per_tok=config_maxtext.num_experts_per_tok,\n",
    "    num_local_experts=config_maxtext.num_experts,\n",
    "    tie_word_embeddings=config_maxtext.logits_via_embedding,\n",
    "    output_router_logits=False,\n",
    "    router_aux_loss_coef=0.001,\n",
    "    router_jitter_noise=0.0,\n",
    "    torch_dtype=\"bfloat16\",\n",
    ")"
   ]
  },
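   {
    "cell_type": "code",
    "execution_count": null,
    "id": "5a1c7e2b-93f4-4d6a-8b0e-2f41c9d7a310",
    "metadata": {},
    "outputs": [],
    "source": [
     "# optional sanity check: the two configs must mirror each other field-for-field\n",
     "# before any weights are compared (assumes pyconfig exposes the derived\n",
     "# num_query_heads field alongside the emb_dim/mlp_dim/num_kv_heads used above)\n",
     "assert config_hf.hidden_size == config_maxtext.emb_dim\n",
     "assert config_hf.intermediate_size == config_maxtext.mlp_dim\n",
     "assert config_hf.num_attention_heads == config_maxtext.num_query_heads\n",
     "assert config_hf.num_key_value_heads == config_maxtext.num_kv_heads\n",
     "assert config_hf.num_local_experts == config_maxtext.num_experts\n",
     "assert config_hf.hidden_size // config_hf.num_attention_heads == config_maxtext.head_dim\n",
     "print(\"configs agree\")"
    ]
   },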
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c94c857a-2efd-48f3-9669-aef926329cbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, set_seed\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from MaxText.layers.models import Transformer\n",
    "from MaxText import maxtext_utils\n",
    "from jax.sharding import Mesh\n",
    "\n",
    "# ensure the same model initialization\n",
    "set_seed(0)\n",
    "\n",
    "model_hf = AutoModelForCausalLM.from_config(config_hf)\n",
    "\n",
    "devices_array = maxtext_utils.create_device_mesh(config_maxtext)\n",
    "mesh = Mesh(devices_array, config_maxtext.mesh_axes)\n",
    "prng_key = jax.random.PRNGKey(1234)\n",
    "model_maxtext = Transformer(config=config_maxtext, mesh=mesh, quant=None)"
   ]
  },
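   {
    "cell_type": "code",
    "execution_count": null,
    "id": "8d4b2f6c-1a7e-4c3d-9e5f-6b0a8c2d4e71",
    "metadata": {},
    "outputs": [],
    "source": [
     "# quick size check on the HF reference model (plain PyTorch, nothing MaxText-specific)\n",
     "n_params_hf = sum(p.numel() for p in model_hf.parameters())\n",
     "print(f\"HF one-layer Mixtral parameters: {n_params_hf:,}\")"
    ]
   },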
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "707df022-ec37-44b3-b203-5f938151c6ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "input_np = {\n",
    "    \"inputs\": np.random.randint(\n",
    "        0, config_maxtext.vocab_size, size=(int(config_maxtext.per_device_batch_size), config_maxtext.max_target_length)\n",
    "    ),\n",
    "    \"inputs_position\": np.tile(\n",
    "        np.arange(config_maxtext.max_target_length), (int(config_maxtext.per_device_batch_size), 1)\n",
    "    ),\n",
    "}"
   ]
  },
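   {
    "cell_type": "code",
    "execution_count": null,
    "id": "2c9e5a7d-4f1b-4e8a-b3c6-0d7f9a1b5e42",
    "metadata": {},
    "outputs": [],
    "source": [
     "# peek at the toy batch: shape (per_device_batch_size, max_target_length) = (1, 4),\n",
     "# with positions simply running 0..max_target_length-1 in each row\n",
     "print(\"inputs:\", input_np[\"inputs\"].shape)\n",
     "print(input_np[\"inputs\"])\n",
     "print(\"positions:\", input_np[\"inputs_position\"].shape)\n",
     "print(input_np[\"inputs_position\"])"
    ]
   },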
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "baca50fb-28f2-48b1-b4f5-0145ac6cfe38",
   "metadata": {},
   "outputs": [],
   "source": [
    "state_maxtext = model_maxtext.init(\n",
    "    {\"params\": prng_key, \"dropout\": prng_key, \"aqt\": prng_key},\n",
    "    jnp.array(input_np[\"inputs\"]),\n",
    "    jnp.array(input_np[\"inputs_position\"]),\n",
    "    enable_dropout=config_maxtext.enable_dropout,\n",
    ")"
   ]
  },
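   {
    "cell_type": "code",
    "execution_count": null,
    "id": "6e3a8c1f-7b2d-4a9e-8f4c-1d5b0e7a3c92",
    "metadata": {},
    "outputs": [],
    "source": [
     "# list the flattened key paths of the freshly initialized MaxText state;\n",
     "# jax.tree_util.keystr renders exactly the strings used as state_map keys below\n",
     "flat_state, _ = jax.tree_util.tree_flatten_with_path(state_maxtext)\n",
     "for key_path, value in flat_state:\n",
     "  print(jax.tree_util.keystr(key_path), value.shape)"
    ]
   },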
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74e8353b-b87a-4c5e-9a7c-138052249250",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from flax import linen as nn\n",
    "\n",
    "state_map = {\n",
    "    \"['params']['decoder']['decoder_norm']['scale'].value\": (\"model.norm.weight\", lambda x: x),\n",
    "    \"['params']['decoder']['layers_0']['MoeBlock_0']['gate']['kernel'].value\": (\n",
    "        \"model.layers.0.block_sparse_moe.gate.weight\",\n",
    "        lambda x: x.T,\n",
    "    ),\n",
    "    \"['params']['decoder']['layers_0']['MoeBlock_0']['wi_0'].value\": (\n",
    "        \"model.layers.0.block_sparse_moe.experts.<exp_idx>.w1.weight\",\n",
    "        lambda *x: torch.stack(*x, dim=0).transpose(1, 2),\n",
    "    ),\n",
    "    \"['params']['decoder']['layers_0']['MoeBlock_0']['wi_1'].value\": (\n",
    "        \"model.layers.0.block_sparse_moe.experts.<exp_idx>.w3.weight\",\n",
    "        lambda *x: torch.stack(*x, dim=0).transpose(1, 2),\n",
    "    ),\n",
    "    \"['params']['decoder']['layers_0']['MoeBlock_0']['wo'].value\": (\n",
    "        \"model.layers.0.block_sparse_moe.experts.<exp_idx>.w2.weight\",\n",
    "        lambda *x: torch.stack(*x, dim=0).transpose(1, 2),\n",
    "    ),\n",
    "    \"['params']['decoder']['layers_0']['post_self_attention_layer_norm']['scale'].value\": (\n",
    "        \"model.layers.0.post_attention_layernorm.weight\",\n",
    "        lambda x: x,\n",
    "    ),\n",
    "    \"['params']['decoder']['layers_0']['pre_self_attention_layer_norm']['scale'].value\": (\n",
    "        \"model.layers.0.input_layernorm.weight\",\n",
    "        lambda x: x,\n",
    "    ),\n",
    "    \"['params']['decoder']['layers_0']['self_attention']['key']['kernel'].value\": (\n",
    "        \"model.layers.0.self_attn.k_proj.weight\",\n",
    "        lambda x: x.T.reshape(config_hf.hidden_size, config_hf.num_key_value_heads, config_maxtext.head_dim),\n",
    "    ),\n",
    "    \"['params']['decoder']['layers_0']['self_attention']['out']['kernel'].value\": (\n",
    "        \"model.layers.0.self_attn.o_proj.weight\",\n",
    "        lambda x: x.T.reshape(config_hf.num_attention_heads, config_maxtext.head_dim, config_hf.hidden_size),\n",
    "    ),\n",
    "    \"['params']['decoder']['layers_0']['self_attention']['query']['kernel'].value\": (\n",
    "        \"model.layers.0.self_attn.q_proj.weight\",\n",
    "        lambda x: x.T.reshape(config_hf.hidden_size, config_hf.num_attention_heads, config_maxtext.head_dim)\n",
    "        / np.sqrt(config_maxtext.head_dim),\n",
    "    ),\n",
    "    \"['params']['decoder']['layers_0']['self_attention']['value']['kernel'].value\": (\n",
    "        \"model.layers.0.self_attn.v_proj.weight\",\n",
    "        lambda x: x.T.reshape(config_hf.hidden_size, config_hf.num_key_value_heads, config_maxtext.head_dim),\n",
    "    ),\n",
    "    \"['params']['decoder']['logits_dense']['kernel'].value\": (\"lm_head.weight\", lambda x: x.T),\n",
    "    \"['params']['token_embedder']['embedding'].value\": (\"model.embed_tokens.weight\", lambda x: x),\n",
    "}\n",
    "\n",
    "state_hf = model_hf.state_dict()\n",
    "\n",
    "\n",
    "def map_fn(key_path, value):\n",
    "  key_path_str = jax.tree_util.keystr(key_path)\n",
    "  torch_key, transform_fn = state_map[key_path_str]\n",
    "  if \"<exp_idx>\" in torch_key:\n",
    "    torch_tensors = [state_hf[torch_key.replace(\"<exp_idx>\", str(i))] for i in range(config_hf.num_local_experts)]\n",
    "  else:\n",
    "    torch_tensors = state_hf[torch_key]\n",
    "\n",
    "  torch_tensors = transform_fn(torch_tensors)\n",
    "\n",
    "  assert value.shape == torch_tensors.shape, f\"{key_path_str}, {value.shape}, {torch_tensors.shape}\"\n",
    "  new_value = jnp.array(torch_tensors.to(torch.float32).numpy(), dtype=value.dtype)\n",
    "  if isinstance(value, nn.LogicallyPartitioned):\n",
    "    new_value = value.replace_boxed(new_value)\n",
    "  return new_value\n",
    "\n",
    "\n",
    "loaded_state_maxtext = jax.tree_util.tree_map_with_path(map_fn, state_maxtext)"
   ]
  },
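   {
    "cell_type": "code",
    "execution_count": null,
    "id": "9b5d3e7a-2c8f-4b1d-a6e9-4f0c7d2a8b53",
    "metadata": {},
    "outputs": [],
    "source": [
     "# supplementary cross-check: after conversion both models should hold the same\n",
     "# total number of weights (tree_leaves unwraps any LogicallyPartitioned boxes);\n",
     "# the per-leaf shape assert in map_fn above remains the real gate\n",
     "n_params_maxtext = sum(int(np.prod(x.shape)) for x in jax.tree_util.tree_leaves(loaded_state_maxtext))\n",
     "n_params_hf = sum(p.numel() for p in model_hf.parameters())\n",
     "print(n_params_maxtext, n_params_hf)"
    ]
   },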
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1f88708-c3a6-4b95-bc51-94adfebdf2aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "logits_hf = model_hf(torch.from_numpy(input_np[\"inputs\"])).logits.detach()\n",
    "\n",
    "logits_maxtext = model_maxtext.apply(\n",
    "    loaded_state_maxtext,\n",
    "    input_np[\"inputs\"],\n",
    "    input_np[\"inputs_position\"],\n",
    "    enable_dropout=False,\n",
    ")"
   ]
  },
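   {
    "cell_type": "code",
    "execution_count": null,
    "id": "4a7c1d9e-8f3b-4c5a-b2d7-6e0a9f4c1b38",
    "metadata": {},
    "outputs": [],
    "source": [
     "# diagnostic before the assertion below: with bfloat16 weights and activations\n",
     "# the two stacks only agree loosely, so inspect the raw deviation first\n",
     "diff = np.abs(np.array(logits_maxtext) - logits_hf.numpy())\n",
     "print(\"max abs diff:\", diff.max(), \"mean abs diff:\", diff.mean())"
    ]
   },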
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1207375a-b92c-4a8c-975a-21f2f027d91e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# currently, pass the following tests in both \"megablox=True\" & \"megablox=False capacity_factor=-1\"\n",
    "\n",
    "np.testing.assert_allclose(np.array(logits_maxtext), logits_hf.numpy(), rtol=1e-1, atol=1e-1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
